{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-beta1\n",
      "sys.version_info(major=3, minor=5, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.0.3\n",
      "numpy 1.16.4\n",
      "pandas 0.24.2\n",
      "sklearn 0.21.2\n",
      "tensorflow 2.0.0-beta1\n",
      "tensorflow.python.keras.api._v2.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import keras\n",
    "\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, tf, keras:\n",
    "    print(module.__name__, module.__version__)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n",
      "4\n"
     ]
    }
   ],
   "source": [
    "tf.debugging.set_log_device_placement(True)\n",
    "gpus = tf.config.experimental.list_physical_devices('GPU')\n",
    "for gpu in gpus:\n",
    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
    "print(len(gpus))\n",
    "logical_gpus = tf.config.experimental.list_logical_devices('GPU')\n",
    "print(len(logical_gpus))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 28, 28) (5000,)\n",
      "(55000, 28, 28) (55000,)\n",
      "(10000, 28, 28) (10000,)\n"
     ]
    }
   ],
   "source": [
    "fashion_mnist = keras.datasets.fashion_mnist\n",
    "(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()\n",
    "x_valid, x_train = x_train_all[:5000], x_train_all[5000:]\n",
    "y_valid, y_train = y_train_all[:5000], y_train_all[5000:]\n",
    "\n",
    "print(x_valid.shape, y_valid.shape)\n",
    "print(x_train.shape, y_train.shape)\n",
    "print(x_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "scaler = StandardScaler()\n",
    "x_train_scaled = scaler.fit_transform(\n",
    "    x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)\n",
    "x_valid_scaled = scaler.transform(\n",
    "    x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)\n",
    "x_test_scaled = scaler.transform(\n",
    "    x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op ShuffleDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op BatchDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n"
     ]
    }
   ],
   "source": [
    "def make_dataset(images, labels, epochs, batch_size, shuffle=True):\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((images, labels))\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(10000)\n",
    "    dataset = dataset.repeat(epochs).batch(batch_size).prefetch(50)\n",
    "    return dataset\n",
    "\n",
    "# batch_size_per_replica = 256\n",
    "# batch_size = batch_size_per_replica * len(logical_gpus)\n",
    "batch_size = 256\n",
    "epochs = 100\n",
    "train_dataset = make_dataset(x_train_scaled, y_train, epochs, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op RandomUniform in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Sub in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Mul in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Add in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op Assert in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op Fill in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n"
     ]
    }
   ],
   "source": [
    "strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "with strategy.scope():\n",
    "    model = keras.models.Sequential()\n",
    "    model.add(keras.layers.Conv2D(filters=128, kernel_size=3,\n",
    "                                  padding='same',\n",
    "                                  activation='relu',\n",
    "                                  input_shape=(28, 28, 1)))\n",
    "    model.add(keras.layers.Conv2D(filters=128, kernel_size=3,\n",
    "                                  padding='same',\n",
    "                                  activation='relu'))\n",
    "    model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "    model.add(keras.layers.Conv2D(filters=256, kernel_size=3,\n",
    "                                  padding='same',\n",
    "                                  activation='relu'))\n",
    "    model.add(keras.layers.Conv2D(filters=256, kernel_size=3,\n",
    "                                  padding='same',\n",
    "                                  activation='relu'))\n",
    "    model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "    model.add(keras.layers.Conv2D(filters=512, kernel_size=3,\n",
    "                                  padding='same',\n",
    "                                  activation='relu'))\n",
    "    model.add(keras.layers.Conv2D(filters=512, kernel_size=3,\n",
    "                                  padding='same',\n",
    "                                  activation='relu'))\n",
    "    model.add(keras.layers.MaxPool2D(pool_size=2))\n",
    "    model.add(keras.layers.Flatten())\n",
    "    model.add(keras.layers.Dense(512, activation='relu'))\n",
    "    model.add(keras.layers.Dense(10, activation=\"softmax\"))\n",
    "\n",
    "    model.compile(loss=\"sparse_categorical_crossentropy\",\n",
    "                  optimizer = \"sgd\",\n",
    "                  metrics = [\"accuracy\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 28, 28, 128)       1280      \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 28, 28, 128)       147584    \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 14, 14, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 14, 14, 256)       295168    \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 14, 14, 256)       590080    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 7, 7, 256)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 7, 7, 512)         1180160   \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 7, 7, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 (None, 3, 3, 512)         0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 4608)              0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 512)               2359808   \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                5130      \n",
      "=================================================================\n",
      "Total params: 6,939,018\n",
      "Trainable params: 6,939,018\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing op ExperimentalRebatchDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op ExperimentalAutoShardDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op OptimizeDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op ModelDataset in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op MultiDeviceIterator in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op MultiDeviceIteratorInit in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op MultiDeviceIteratorToStringHandle in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op MakeIterator in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op MakeIterator in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op MakeIterator in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op MakeIterator in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Train on 214 steps\n",
      "Epoch 1/10\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op __inference_initialize_variables_2831 in device <unspecified>\n",
      "Executing op __inference_distributed_function_4215 in device <unspecified>\n",
      "214/214 [==============================] - 47s 221ms/step - loss: 1.8203 - accuracy: 0.4841\n",
      "Epoch 2/10\n",
      "214/214 [==============================] - 7s 33ms/step - loss: 0.7559 - accuracy: 0.7126\n",
      "Epoch 3/10\n",
      "214/214 [==============================] - 7s 33ms/step - loss: 0.6088 - accuracy: 0.7705\n",
      "Epoch 4/10\n",
      "214/214 [==============================] - 7s 33ms/step - loss: 0.5282 - accuracy: 0.8030\n",
      "Epoch 5/10\n",
      "214/214 [==============================] - 7s 33ms/step - loss: 0.4780 - accuracy: 0.8218\n",
      "Epoch 6/10\n",
      "214/214 [==============================] - 7s 33ms/step - loss: 0.4450 - accuracy: 0.8354\n",
      "Epoch 7/10\n",
      "214/214 [==============================] - 7s 33ms/step - loss: 0.4133 - accuracy: 0.8486\n",
      "Epoch 8/10\n",
      "214/214 [==============================] - 7s 33ms/step - loss: 0.3924 - accuracy: 0.8551\n",
      "Epoch 9/10\n",
      "214/214 [==============================] - 7s 33ms/step - loss: 0.3760 - accuracy: 0.8617\n",
      "Epoch 10/10\n",
      "214/214 [==============================] - 7s 33ms/step - loss: 0.3577 - accuracy: 0.8700\n",
      "Executing op DestroyResourceOp in device /job:localhost/replica:0/task:0/device:CPU:0\n",
      "Executing op DeleteIterator in device /job:localhost/replica:0/task:0/device:GPU:3\n",
      "Executing op DeleteIterator in device /job:localhost/replica:0/task:0/device:GPU:2\n",
      "Executing op DeleteIterator in device /job:localhost/replica:0/task:0/device:GPU:1\n",
      "Executing op DeleteIterator in device /job:localhost/replica:0/task:0/device:GPU:0\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(train_dataset,\n",
    "                    steps_per_epoch = x_train_scaled.shape[0] // batch_size,\n",
    "                    epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
